#!/usr/bin/env python
# -*- coding: utf-8 -*-

import operator
import pickle
from math import log

from src.manager.file_manager import FileManager
from src.manager.log_manager import LogManager

# Module-level logger for this module (not referenced in the visible code below).
Logger = LogManager.get_logger(__name__)


class TreeManager:
    """
    ID3 decision tree.

    Builds a tree from labelled feature vectors, classifies new samples,
    and persists trees with pickle. Each sample is a list whose last
    element is the class label; the remaining elements are discrete
    feature values.
    """

    def __init__(self):
        pass

    def calculate_shannon_entropy(self, data_set):
        """
        Compute the Shannon entropy of the class labels in *data_set*.

        :param data_set: list of feature vectors; the last element of each
                         vector is the class label
        :return: entropy in bits (0.0 for a pure data set)
        """
        line_number = len(data_set)
        label_counts = {}
        for feat_vector in data_set:
            current_label = feat_vector[-1]
            # dict.get replaces the membership-test-then-increment idiom
            label_counts[current_label] = label_counts.get(current_label, 0) + 1
        shannon_entropy = 0.0
        for count in label_counts.values():
            prob = float(count) / line_number
            shannon_entropy -= prob * log(prob, 2)
        return shannon_entropy

    def split_data_set(self, data_set, axis, value):
        """
        Select the samples whose feature at index *axis* equals *value*.

        :param data_set: list of feature vectors
        :param axis: feature index, starting from 0
        :param value: feature value to match
        :return: matching vectors with the *axis* column removed, so the
                 consumed feature is not reused further down the tree
        """
        return_data_set = []
        for feat_vector in data_set:
            if feat_vector[axis] == value:
                return_data_set.append(feat_vector[:axis] + feat_vector[axis + 1:])
        return return_data_set

    def choose_best_feature_to_split(self, data_set):
        """
        Pick the feature with the highest information gain.

        :param data_set: list of feature vectors (label in last position)
        :return: index of the best feature, or -1 when no split improves
                 on the base entropy
        """
        features_number = len(data_set[0]) - 1
        base_entropy = self.calculate_shannon_entropy(data_set)
        best_info_gain = 0.0
        best_feature = -1
        for i in range(features_number):
            unique_values = {example[i] for example in data_set}
            # expected (weighted) entropy after splitting on feature i
            new_entropy = 0.0
            for value in unique_values:
                sub_data_set = self.split_data_set(data_set, i, value)
                prob = len(sub_data_set) / float(len(data_set))
                new_entropy += prob * self.calculate_shannon_entropy(sub_data_set)
            info_gain = base_entropy - new_entropy
            if info_gain > best_info_gain:
                best_info_gain = info_gain
                best_feature = i
        return best_feature

    def majority_count(self, class_list):
        """
        Return the class label that occurs most often in *class_list*.

        Bug fix: the increment used to sit inside the first-occurrence
        branch, so every label counted as 1 and the "majority" was just
        the first label encountered.

        :param class_list: list of class labels
        :return: the most frequent label (ties: first-seen wins)
        """
        class_count = {}
        for vote in class_list:
            class_count[vote] = class_count.get(vote, 0) + 1
        return max(class_count.items(), key=operator.itemgetter(1))[0]

    def create_tree(self, data_set, labels):
        """
        Recursively build a decision tree.

        :param data_set: training samples (label in last position)
        :param labels: feature names; NOTE: mutated in place — the chosen
                       feature's name is deleted at each level
        :return: nested dict {feature_label: {value: subtree_or_label}},
                 or a bare class label for a leaf
        """
        class_list = [example[-1] for example in data_set]
        # leaf: every sample shares the same class
        if class_list.count(class_list[0]) == len(class_list):
            return class_list[0]
        # leaf: all features consumed — fall back to majority vote
        if len(data_set[0]) == 1:
            return self.majority_count(class_list)
        best_feat = self.choose_best_feature_to_split(data_set)
        best_feat_label = labels[best_feat]
        my_tree = {best_feat_label: {}}
        del labels[best_feat]
        unique_values = {example[best_feat] for example in data_set}
        for value in unique_values:
            # copy so sibling branches each see the same remaining labels
            sub_labels = labels[:]
            my_tree[best_feat_label][value] = self.create_tree(
                self.split_data_set(data_set, best_feat, value), sub_labels)
        return my_tree

    def classify(self, input_tree, feat_labels, test_vector):
        """
        Classify *test_vector* with a tree built by create_tree.

        :param input_tree: nested dict tree from create_tree
        :param feat_labels: feature names in the same order used for training
        :param test_vector: feature values of the sample to classify
        :return: predicted class label, or None when the tree has no
                 branch for one of the sample's feature values
        """
        first_str = list(input_tree.keys())[0]
        second_dict = input_tree[first_str]
        feat_index = feat_labels.index(first_str)
        class_label = None
        for key, subtree in second_dict.items():
            if test_vector[feat_index] == key:
                if isinstance(subtree, dict):
                    class_label = self.classify(subtree, feat_labels, test_vector)
                else:
                    class_label = subtree
        return class_label

    def write_tree(self, input_tree, filename):
        """
        Persist a decision tree to *filename* with pickle.

        :param input_tree: tree produced by create_tree
        :param filename: destination path
        """
        # context manager guarantees the handle is closed even on error
        with open(filename, 'wb') as fw:
            pickle.dump(input_tree, fw)

    def read_tree(self, filename):
        """
        Load a decision tree previously saved by write_tree.

        WARNING: pickle.load executes arbitrary code — only read files
        from trusted sources.

        :param filename: path to a pickled tree
        :return: the unpickled tree
        """
        # 'rb' (not 'rb+') for a pure read; the original also leaked the handle
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
